[not for land] online fp8 quant with streaming weight post-processing #29196
base: main
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Code Review
This pull request introduces a proof-of-concept for online FP8 quantization with streaming weight post-processing. The approach is clever: the weight loader is patched to trigger post-processing as soon as a weight tensor is fully loaded, which can reduce peak memory usage during model loading.
My main feedback is to improve the robustness of the state management. The flag `_already_called_process_weights_after_loading` is currently stored on the `Fp8LinearMethod` instance. While this works with the current code structure, it is fragile. Attaching this state to the layer object instead would make the implementation more robust against future changes, such as instance reuse for optimization. I've added specific comments with code suggestions to address this.
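As a minimal, self-contained sketch of the streaming pattern under review (illustrative names, not the PR's actual diff, which follows below; assumes a recent PyTorch with float8 support):

```python
import torch
from torch import nn


def make_streaming_loader(layer: nn.Module, base_loader, on_complete):
    """Wrap a weight loader so post-processing runs as soon as a
    parameter is fully loaded, rather than after the whole model."""

    def loader(param, loaded_weight, *args, **kwargs):
        base_loader(param, loaded_weight, *args, **kwargs)
        # Count how many elements of this parameter have been written.
        # Assumes each weight chunk is loaded exactly once.
        loaded = getattr(param, "_loaded_numel", 0) + loaded_weight.numel()
        param._loaded_numel = loaded
        if loaded == param.numel():
            del param._loaded_numel
            on_complete(layer)

    return loader


def fp8_post_process(layer: nn.Module) -> None:
    # Illustrative per-tensor fp8 quantization: compute a scale, cast
    # the weight, and drop the high-precision copy immediately.
    w = layer.weight.data
    scale = w.abs().max() / torch.finfo(torch.float8_e4m3fn).max
    layer.weight_scale = scale
    layer.weight.data = (w / scale).to(torch.float8_e4m3fn)
```

Because each completed tensor is quantized and its high-precision copy freed before the next one finishes loading, peak memory stays close to the quantized footprint rather than the full-precision one.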
```python
del param._loaded_numel
# Prevent the usual `process_weights_after_loading` call from doing
# anything
self._already_called_process_weights_after_loading = True
```
Storing `_already_called_process_weights_after_loading` on `self` (the `Fp8LinearMethod` instance) makes the design fragile. Although a new instance is currently created for each layer, this might change in the future (e.g., for optimization), which could lead to the flag persisting incorrectly across different layers.
To make this more robust, this state should be attached to the `layer` object, which is guaranteed to be unique. This change should be made in conjunction with the corresponding check in `process_weights_after_loading`.
Suggested change:
```diff
- self._already_called_process_weights_after_loading = True
+ layer._already_called_process_weights_after_loading = True
```
```python
        layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        if getattr(self, "_already_called_process_weights_after_loading", False):
```
To make the state management more robust, and in conjunction with the suggested change above for setting this flag, this check should also be performed on the `layer` object instead of `self`.
Suggested change:
```diff
- if getattr(self, "_already_called_process_weights_after_loading", False):
+ if getattr(layer, "_already_called_process_weights_after_loading", False):
```
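Taken together, the two suggestions key the guard on the layer rather than on the method instance. A sketch of the resulting flow (illustrative; `_quantize` is a hypothetical helper standing in for the PR's fp8 post-processing):

```python
class Fp8LinearMethodSketch:
    def _on_weight_fully_loaded(self, layer) -> None:
        # Streaming path: quantize now, and record that fact on the
        # layer so the usual end-of-load pass becomes a no-op.
        self._quantize(layer)
        layer._already_called_process_weights_after_loading = True

    def process_weights_after_loading(self, layer) -> None:
        # A method instance could one day be shared across layers, so
        # the guard lives on the (unique) layer object, not on self.
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
        self._quantize(layer)

    def _quantize(self, layer) -> None:
        ...  # fp8 scale computation and cast, as in the PR
```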
Summary:

not for land, just a demo

1. during weight loading, keep track of how many elements we have loaded
2. when we have loaded all the elements, call post-processing

This can be used to call weight post-processing in a streaming fashion to minimize GPU memory usage. It will only work if we can assume each weight chunk is loaded exactly once.

Test Plan: tested locally with facebook/opt-125m and `fp8` online quantization

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: <[email protected]>
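For reference, the test plan can be exercised through vLLM's standard entry point (a sketch, not the author's exact test script):

```python
from vllm import LLM, SamplingParams

# facebook/opt-125m with online fp8 quantization; with this PR, weight
# post-processing would run in a streaming fashion during loading.
llm = LLM(model="facebook/opt-125m", quantization="fp8")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```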
Force-pushed from 5326892 to 9583e3b.